{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0c6e8fa3-2900-4879-a610-3c038f8c8d23",
   "metadata": {},
   "source": [
    "# EinMix: universal toolkit for advanced MLP architectures\n",
    "\n",
    "Recent progress in MLP-based architectures demonstrated that *very specific* MLPs can compete with convnets and transformers (and even outperform them).\n",
    "\n",
    "EinMix allows writing such architectures in a more uniform and readable way."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dc9a3913-5d5f-41da-8946-f17b3e537ec4",
   "metadata": {},
   "source": [
    "## EinMix — building block of MLPs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5f24a8f2-6680-4298-9736-081719c53f88",
   "metadata": {},
   "outputs": [],
   "source": [
    "from einops.layers.torch import EinMix as Mix"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eacd97e9-7ae2-4321-a3ad-3fb9b122fd81",
   "metadata": {},
   "source": [
    "Logic of EinMix is very close to the one of `einsum`. \n",
    "If you're not familiar with einsum, follow these guides first:\n",
    "\n",
    "- https://rockt.github.io/2018/04/30/einsum\n",
    "- https://towardsdatascience.com/einsum-an-underestimated-function-99ca96e2942e\n",
    "- https://theaisummer.com/einsum-attention/\n",
    "\n",
    "Einsum uniformly describes a number of operations, however `EinMix` is defined slightly differently.\n",
    "\n",
    "Here is a linear layer, a common block in sequence modelling (e.g. in NLP/speech), written with einsum\n",
    "```python\n",
    "weight = <...create tensor...>\n",
    "result = torch.einsum('tbc,cd->tbd', embeddings, weight)\n",
    "```\n",
    "\n",
    "EinMix counter-part is:\n",
    "```python\n",
    "mix_channels = Mix('t b c -> t b c_out', weight_shape='c c_out', ...)\n",
    "result = mix_channels(embeddings)\n",
    "```\n",
    "\n",
    "Main differences compared to plain `einsum` are:\n",
    "\n",
    "- layer takes care of the weight initialization & management hassle\n",
    "- weight is not in the comprehension\n",
    "\n",
    "We'll discuss other changes a bit later, now let's implement ResMLP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d6983748-20d9-4825-a79f-a0641e58085c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# other stuff we use\n",
    "import torch\n",
    "from torch import nn\n",
    "from einops.layers.torch import Rearrange, Reduce"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "928fc321-7b41-4737-bbf4-644021c1c209",
   "metadata": {},
   "source": [
    "## ResMLP — original implementation\n",
    "\n",
    "Building blocks of ResMLP consist only of linear/affine layers and one activation (GELU). <br />\n",
    "Let's see how we can rewrite all of the components with Mix. \n",
    "\n",
    "We start from a reference code for ResMLP block published in the [paper](https://arxiv.org/pdf/2105.03404.pdf):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "eed8edf4-0024-482a-bba2-03e84d4086ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# No norm layer\n",
    "class Affine(nn.Module):\n",
    "    def __init__(self, dim):\n",
    "        super().__init__()\n",
    "        self.alpha = nn.Parameter(torch.ones(dim))\n",
    "        self.beta = nn.Parameter(torch.zeros(dim))\n",
    "        \n",
    "    def forward(self, x):\n",
    "        return self.alpha * x + self.beta\n",
    "    \n",
    "\n",
    "class Mlp(nn.Module):\n",
    "    def __init__(self, dim):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(dim, 4 * dim)\n",
    "        self.act = nn.GELU()\n",
    "        self.fc2 = nn.Linear(4 * dim, dim)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        x = self.fc1(x)\n",
    "        x = self.act(x)\n",
    "        x = self.fc2(x)\n",
    "        return x\n",
    "    \n",
    "class ResMLP_Blocks(nn.Module):\n",
    "    def __init__(self, nb_patches, dim, layerscale_init):\n",
    "        super().__init__()\n",
    "        self.affine_1 = Affine(dim)\n",
    "        self.affine_2 = Affine(dim)\n",
    "        self.linear_patches = nn.Linear(nb_patches, nb_patches) #Linear layer on patches\n",
    "        self.mlp_channels = Mlp(dim) #MLP on channels\n",
    "        self.layerscale_1 = nn.Parameter(layerscale_init * torch.ones((dim))) # LayerScale\n",
    "        self.layerscale_2 = nn.Parameter(layerscale_init * torch.ones((dim))) # parameters\n",
    "        \n",
    "    def forward(self, x):\n",
    "        res_1 = self.linear_patches(self.affine_1(x).transpose(1,2)).transpose(1,2)\n",
    "        x = x + self.layerscale_1 * res_1\n",
    "        res_2 = self.mlp_channels(self.affine_2(x))\n",
    "        x = x + self.layerscale_2 * res_2\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7496cf7b-a3e1-4648-b62b-30de26c2e718",
   "metadata": {},
   "source": [
    "## ResMLP &mdash; rewritten\n",
    "\n",
    "Code below is the result of first rewriting: \n",
    "- combination [transpose -> linear -> transpose back] got nicely packed into a single `EinMix` (`mix_patches`) <br />\n",
    "  `Mix('b t c -> b t0 c', weight_shape='t t0', bias_shape='t0', t=nb_patches, t0=nb_patches)` \n",
    "    - pattern `'b t c -> b t0 c'` tells that `b` and `c` are unperturbed, while tokens `t->t0` were mixed\n",
    "    - explicit parameter shapes are also quite insightful\n",
    "      \n",
    "- In new implementation affine layer is also handled by `EinMix`: <br />\n",
    "  `Mix('b t c -> b t c', weight_shape='c', bias_shape='c', c=dim)`\n",
    "  - from the pattern you can see that there is no mixing at all, only multiplication and shift\n",
    "  - multiplication and shift are defined by weight and bias - and those depend only on a channel\n",
    "  - thus affine transform is per-channel\n",
    "  \n",
    "- Linear layer is also handled by EinMix, the only difference compared to affine layer is absence of bias\n",
    "- We specified that input is 3d and order is `btc`, not `tbc` - this is not written explicitly in the original code\n",
    "\n",
    "The only step back that we had to do is change an initialization schema for EinMix for affine and linear layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a9cc8c91-d80a-4c35-a010-dbafad22b8c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def Mlp(dim):\n",
    "    return nn.Sequential(\n",
    "        nn.Linear(dim, 4 * dim),\n",
    "        nn.GELU(),\n",
    "        nn.Linear(4 * dim, dim),\n",
    "    )\n",
    "\n",
    "def init(Mix_layer, scale=1.):\n",
    "    Mix_layer.weight.data[:] = scale\n",
    "    if Mix_layer.bias is not None:\n",
    "        Mix_layer.bias.data[:] = 0\n",
    "    return Mix_layer\n",
    "    \n",
    "class ResMLP_Blocks2(nn.Module):\n",
    "    def __init__(self, nb_patches, dim, layerscale_init):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.affine1 = init(Mix('b t c -> b t c', weight_shape='c', bias_shape='c', c=dim))\n",
    "        self.affine2 = init(Mix('b t c -> b t c', weight_shape='c', bias_shape='c', c=dim))\n",
    "        self.mix_patches = Mix('b t c -> b t0 c', weight_shape='t t0', bias_shape='t0', t=nb_patches, t0=nb_patches)\n",
    "        self.mlp_channels = Mlp(dim)\n",
    "        self.linear1 = init(Mix('b t c -> b t c', weight_shape='c', c=dim), scale=layerscale_init)\n",
    "        self.linear2 = init(Mix('b t c -> b t c', weight_shape='c', c=dim), scale=layerscale_init)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        res1 = self.mix_patches(self.affine1(x))\n",
    "        x = x + self.linear1(res1)\n",
    "        res2 = self.mlp_channels(self.affine2(x))\n",
    "        x = x + self.linear2(res2)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "761d0d39-b041-4389-a19d-c3d512d64624",
   "metadata": {},
   "source": [
    "## ResMLP &mdash; rewritten more\n",
    "\n",
    "Since here in einops-land we care about code being easy to follow, let's make one more transformation.\n",
    "\n",
    "We group layers from both branches, and now the order of operations matches the order as they are written in the code.\n",
    "\n",
    "Could we go further? Actually, yes - `nn.Linear` layers can also be replaced by EinMix,\n",
    "however they are very organic here since first and last operations in `branch_channels` show components.\n",
    "\n",
    "Brevity of `nn.Linear` is benefitial when the context specifies tensor shapes.\n",
    "\n",
    "Other interesing observations:\n",
    "- hard to notice in the original code `nn.Linear` is preceded by a linear layer (thus latter is redundant or can be fused in the former)\n",
    "- hard to notice in the original code second `nn.Linear` is followed by an affine layer (thus latter is again redundant)\n",
    "\n",
    "Take time to reorganize your code. This may be quite insightful."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6fb50fac-1ef4-45a2-acdb-48f1050df7ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "def init(layer: Mix, scale=1.):\n",
    "    layer.weight.data[:] = scale\n",
    "    if layer.bias is not None:\n",
    "        layer.bias.data[:] = 0\n",
    "    return layer\n",
    "\n",
    "class ResMLP_Blocks3(nn.Module):\n",
    "    def __init__(self, nb_patches, dim, layerscale_init):\n",
    "        super().__init__()\n",
    "        self.branch_patches = nn.Sequential(\n",
    "            init(Mix('b t c -> b t c', weight_shape='c', c=dim), scale=layerscale_init),\n",
    "            Mix('b t c -> b t0 c', weight_shape='t t0', bias_shape='t0', t=nb_patches, t0=nb_patches),\n",
    "            init(Mix('b t c -> b t c', weight_shape='c', bias_shape='c', c=dim)),\n",
    "        )\n",
    "        \n",
    "        self.branch_channels = nn.Sequential(\n",
    "            init(Mix('b t c -> b t c', weight_shape='c', c=dim), scale=layerscale_init),\n",
    "            nn.Linear(dim, 4 * dim),\n",
    "            nn.GELU(),\n",
    "            nn.Linear(4 * dim, dim),\n",
    "            init(Mix('b t c -> b t c', weight_shape='c', bias_shape='c', c=dim)),\n",
    "        )\n",
    "        \n",
    "    def forward(self, x):\n",
    "        x = x + self.branch_patches(x)\n",
    "        x = x + self.branch_channels(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad7f343e-b81d-4b25-96bd-4c48ff33f9c8",
   "metadata": {},
   "source": [
    "## ResMLP &mdash; performance\n",
    "\n",
    "There is some fear of using einsum because historically it lagged in performance.\n",
    "\n",
    "Below we run a test and verify that performace didn't change after transition to `EinMix`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3987bbbc-ccc2-4374-8979-fd3fb29c4910",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "28.1 ms ± 1.61 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n",
      "26.3 ms ± 620 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n",
      "25.9 ms ± 706 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n",
      "26.8 ms ± 2.99 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n",
      "25.9 ms ± 794 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n",
      "25.6 ms ± 723 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
     ]
    }
   ],
   "source": [
    "x = torch.zeros([32, 128, 128])\n",
    "for layer in [\n",
    "    ResMLP_Blocks(128, dim=128, layerscale_init=1.),\n",
    "    ResMLP_Blocks2(128, dim=128, layerscale_init=1.),\n",
    "    ResMLP_Blocks3(128, dim=128, layerscale_init=1.),\n",
    "    # scripted versions    \n",
    "    torch.jit.script(ResMLP_Blocks(128, dim=128, layerscale_init=1.)),\n",
    "    torch.jit.script(ResMLP_Blocks2(128, dim=128, layerscale_init=1.)),\n",
    "    torch.jit.script(ResMLP_Blocks3(128, dim=128, layerscale_init=1.)),\n",
    "]:\n",
    "    %timeit -n 10 y = layer(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "05428b35-5788-4e7f-867e-aa742d7215e7",
   "metadata": {},
   "source": [
    "## TokenMixer from MLPMixer — original code\n",
    "\n",
    "Let's now delve into MLPMixer. We start from pytorch [implementation](https://github.com/jaketae/mlp-mixer/blob/e7d68dfc31e94721724689e6ec90f05806b50124/mlp_mixer/core.py) by Jake Tae.\n",
    "\n",
    "We'll focus on two components of MLPMixer that don't exist in convnets. First component is TokenMixer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "8f6a6073-44af-40ec-8154-02ae8370f749",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.nn import functional as F\n",
    "\n",
    "class MLP(nn.Module):\n",
    "    def __init__(self, num_features, expansion_factor, dropout):\n",
    "        super().__init__()\n",
    "        num_hidden = num_features * expansion_factor\n",
    "        self.fc1 = nn.Linear(num_features, num_hidden)\n",
    "        self.dropout1 = nn.Dropout(dropout)\n",
    "        self.fc2 = nn.Linear(num_hidden, num_features)\n",
    "        self.dropout2 = nn.Dropout(dropout)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.dropout1(F.gelu(self.fc1(x)))\n",
    "        x = self.dropout2(self.fc2(x))\n",
    "        return x\n",
    "\n",
    "\n",
    "class TokenMixer(nn.Module):\n",
    "    def __init__(self, num_features, num_patches, expansion_factor, dropout):\n",
    "        super().__init__()\n",
    "        self.norm = nn.LayerNorm(num_features)\n",
    "        self.mlp = MLP(num_patches, expansion_factor, dropout)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # x.shape == (batch_size, num_patches, num_features)\n",
    "        residual = x\n",
    "        x = self.norm(x)\n",
    "        x = x.transpose(1, 2)\n",
    "        # x.shape == (batch_size, num_features, num_patches)\n",
    "        x = self.mlp(x)\n",
    "        x = x.transpose(1, 2)\n",
    "        # x.shape == (batch_size, num_patches, num_features)\n",
    "        out = x + residual\n",
    "        return out"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "362fa6ac-7146-4f7b-9e1d-0765486c26da",
   "metadata": {},
   "source": [
    "## TokenMixer from MLPMixer — reimplemented\n",
    "\n",
    "We can significantly reduce amount of code by using `EinMix`. \n",
    "\n",
    "- Main caveat addressed by original code is that `nn.Linear` mixes only last axis. `EinMix` can mix any axis.\n",
    "- Sequential structure is always preferred as it is easier to follow\n",
    "- Intentionally there is no residual connection in `TokenMixer`, because honestly it's not work of Mixer and should be done by caller\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f447c6c5-932f-4929-991a-f97b14776618",
   "metadata": {},
   "outputs": [],
   "source": [
    "def TokenMixer(num_features: int, n_patches: int, expansion_factor: int, dropout: float):\n",
    "    n_hidden = n_patches * expansion_factor\n",
    "    return nn.Sequential(\n",
    "        nn.LayerNorm(num_features),\n",
    "        Mix('b hw c -> b hid c', weight_shape='hw hid', bias_shape='hid', hw=n_patches, hidden=n_hidden),\n",
    "        nn.GELU(),\n",
    "        nn.Dropout(dropout),\n",
    "        Mix('b hid c -> b hw c', weight_shape='hid hw', bias_shape='hw',  hw=n_patches, hidden=n_hidden),\n",
    "        nn.Dropout(dropout),\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "19caacfc-a066-4727-918e-b3a98a190014",
   "metadata": {},
   "source": [
    "You may also like independent [implementation](https://github.com/lucidrains/mlp-mixer-pytorch/blob/main/mlp_mixer_pytorch/mlp_mixer_pytorch.py) of MLPMixer from Phil Wang. <br />\n",
    "Phil solves the issue by repurposing `nn.Conv1d` to mix on the second dimension. Hacky, but does the job\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5210dfd1-0195-458f-a8b9-819547123bad",
   "metadata": {},
   "source": [
    "## MLPMixer's patch embeddings — original\n",
    "\n",
    "Second interesting part of MLPMixer is derived from vision transformers.\n",
    "\n",
    "In the very beginning an image is split into patches, and each patch is linearly projected into embedding.\n",
    "\n",
    "I've taken the part of Jake's code responsible for embedding patches:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "24752920-006c-49e1-acb9-227543f84afb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_sizes(image_size, patch_size):\n",
    "    sqrt_num_patches, remainder = divmod(image_size, patch_size)\n",
    "    assert remainder == 0, \"`image_size` must be divisibe by `patch_size`\"\n",
    "    num_patches = sqrt_num_patches ** 2\n",
    "    return num_patches\n",
    "\n",
    "class Patcher(nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        image_size=256,\n",
    "        patch_size=16,\n",
    "        in_channels=3,\n",
    "        num_features=128,\n",
    "    ):\n",
    "        num_patches = check_sizes(image_size, patch_size)\n",
    "        super().__init__()\n",
    "        # per-patch fully-connected is equivalent to strided conv2d\n",
    "        self.patcher = nn.Conv2d(\n",
    "            in_channels, num_features, kernel_size=patch_size, stride=patch_size\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        patches = self.patcher(x)\n",
    "        batch_size, num_features, _, _ = patches.shape\n",
    "        patches = patches.permute(0, 2, 3, 1)\n",
    "        patches = patches.view(batch_size, -1, num_features)\n",
    "        \n",
    "        return patches"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b5ecaac5-d5fe-48b7-9339-fe84fdccd9ac",
   "metadata": {},
   "source": [
    "## MLPMixer's patch embeddings — reimplemented\n",
    "\n",
    "`EinMix` does this in a single operation. This may require some training at first to understand.\n",
    "\n",
    "Let's go step-by-step:\n",
    "\n",
    "- `b c_in (h hp) (w wp) ->` - 4-dimensional input tensor (BCHW-ordered) is split into patches of shape `hp x wp`\n",
    "- `weight_shape='c_in hp wp c'`. Axes `c_in`, `hp` and `wp` are all absent in the output: three dimensional patch tensor was *mixed* to produce a vector of length `c`\n",
    "-  `-> b (h w) c` - output is 3-dimensional. All patches were reorganized from `h x w` grid to one-dimensional sequence of vectors\n",
    "\n",
    "\n",
    "We don't need to provide image_size beforehead, new implementation handles images of different dimensions as long as they can be divided into patches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "2b786a62-a742-49d3-9e66-d02eb2f4e384",
   "metadata": {},
   "outputs": [],
   "source": [
    "def patcher(patch_size=16, in_channels=3, num_features=128):\n",
    "    return Mix('b c_in (h hp) (w wp) -> b (h w) c', weight_shape='c_in hp wp c', bias_shape='c',\n",
    "                  c=num_features, hp=patch_size, wp=patch_size, c_in=in_channels)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a26b2f34-07cd-4fc2-bdf9-0043dd3451ea",
   "metadata": {},
   "source": [
    "## Vision Permutator\n",
    "\n",
    "As a third example we consider pytorch-like code from [ViP paper](https://arxiv.org/pdf/2106.12368.pdf).\n",
    "\n",
    "Vision permutator is only slightly more nuanced than previous models, because \n",
    "1. it operates on spatial dimensions separately, while MLPMixer and its friends just pack all spatial info into one axis. \n",
    "2. it splits channels into groups called 'segments'\n",
    "\n",
    "Paper provides pseudo-code, so I reworked that to complete module with minimal changes. Enjoy:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "2c94453b-bc78-4f06-8d51-f110e4b6501a",
   "metadata": {},
   "outputs": [],
   "source": [
    "class WeightedPermuteMLP(nn.Module):\n",
    "    def __init__(self, H, W, C, S):\n",
    "        super().__init__()\n",
    "\n",
    "        self.proj_h = nn.Linear(H * S, H * S)\n",
    "        self.proj_w = nn.Linear(W * S, W * S)\n",
    "        self.proj_c = nn.Linear(C, C)\n",
    "        self.proj = nn.Linear(C, C)\n",
    "        self.S = S\n",
    "\n",
    "    def forward(self, x):\n",
    "        B, H, W, C = x.shape\n",
    "        S = self.S\n",
    "        N = C // S\n",
    "        x_h = x.reshape(B, H, W, N, S).permute(0, 3, 2, 1, 4).reshape(B, N, W, H*S)\n",
    "        x_h = self.proj_h(x_h).reshape(B, N, W, H, S).permute(0, 3, 2, 1, 4).reshape(B, H, W, C)\n",
    "\n",
    "        x_w = x.reshape(B, H, W, N, S).permute(0, 1, 3, 2, 4).reshape(B, H, N, W*S)\n",
    "        x_w = self.proj_w(x_w).reshape(B, H, N, W, S).permute(0, 1, 3, 2, 4).reshape(B, H, W, C)\n",
    "\n",
    "        x_c = self.proj_c(x)\n",
    "        \n",
    "        x = x_h + x_w + x_c\n",
    "        x = self.proj(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27a4aa53-e719-4079-973a-83b07d507c1f",
   "metadata": {},
   "source": [
    "That didn't look readable, right? \n",
    "\n",
    "This code is also very inflexible: code in the paper did not support batch dimension, and multiple changes were necessary to allow batch processing. <br />\n",
    "This process is fragile and easily can result in virtually uncatchable bugs.\n",
    "\n",
    "Now good news: each of these long method chains can be replaced with a single `EinMix` layer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "a0a699c0-f0b9-4e2c-8dc1-b33f5a7dfa1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "class WeightedPermuteMLP_new(nn.Module):\n",
    "    def __init__(self, H, W, C, seg_len):\n",
    "        super().__init__()\n",
    "        assert C % seg_len == 0, f\"can't divide {C} into segments of length {seg_len}\"\n",
    "        self.mlp_c = Mix('b h w c -> b h w c0', weight_shape='c c0', bias_shape='c0', c=C, c0=C)\n",
    "        self.mlp_h = Mix('b h w (n c) -> b h0 w (n c0)', weight_shape='h c h0 c0', bias_shape='h0 c0',\n",
    "                            h=H, h0=H, c=seg_len, c0=seg_len)\n",
    "        self.mlp_w = Mix('b h w (n c) -> b h w0 (n c0)', weight_shape='w c w0 c0', bias_shape='w0 c0',\n",
    "                            w=W, w0=W, c=seg_len, c0=seg_len)\n",
    "        self.proj = nn.Linear(C, C)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.mlp_c(x) + self.mlp_h(x) + self.mlp_w(x)\n",
    "        return self.proj(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a9a7165-6c58-409f-b277-f4f7314e57fe",
   "metadata": {},
   "source": [
    "Great, now let's confirm that performance did not deteriorate."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "f375fcd7-8238-470d-aeaf-976892062911",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "90.5 ms ± 1.22 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n",
      "91.5 ms ± 616 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n",
      "91.8 ms ± 626 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n",
      "87.4 ms ± 3.59 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
     ]
    }
   ],
   "source": [
    "x = torch.zeros([32, 32, 32, 128])\n",
    "\n",
    "for layer in [\n",
    "    WeightedPermuteMLP(H=32, W=32, C=128, S=4),\n",
    "    WeightedPermuteMLP_new(H=32, W=32, C=128, seg_len=4),\n",
    "    # scripted versions\n",
    "    torch.jit.script(WeightedPermuteMLP(H=32, W=32, C=128, S=4)),\n",
    "    torch.jit.script(WeightedPermuteMLP_new(H=32, W=32, C=128, seg_len=4)),\n",
    "]:    \n",
    "    %timeit -n 10 y = layer(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6694ade8-557d-4461-8f3a-4bd5f74dbea2",
   "metadata": {},
   "source": [
    "## Final remarks\n",
    "\n",
    "`EinMix` has an incredible potential: \n",
    "it helps with MLPs that don't fit into a limited 'mix all in the last axis' paradigm.\n",
    "\n",
    "However existing research is ... very limited, it does not cover real possibilities of densely connected architectures.\n",
    "\n",
    "Most of its *systematic* novelty is \"mix along spacial axes too\". \n",
    "But `EinMix` provides **an astonishing amount of other possibilities!**"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1ee15e6-933a-4c3e-842e-e908a0aeee15",
   "metadata": {},
   "source": [
    "### Groups of mixers\n",
    "\n",
    "You can find two settings compared in the MLPMixer paper (Supplementary A1)\n",
    "```python\n",
    "'b hw c -> b hw_out c', weight_shape='hw hw_out'\n",
    "```\n",
    "and\n",
    "```python\n",
    "'b hw c -> b hw_out c', weight_shape='c hw hw_out'\n",
    "```\n",
    "While latter makes more sense (why mixing should work similarly for all channels?), the former performs better.\n",
    "\n",
    "So one more question is reasonable: what if channels are split into groups, and mixing is defined for each group?\n",
    "```python\n",
    "'b hw (group c) -> b hw_out (group c)', weight_shape='group hw hw_out'\n",
    "```\n",
    "Implementing such setting without einops is considerably harder.\n",
    "\n",
    "### Mixing within patch on a grid\n",
    "\n",
    "What if you make mixing 'local' in space? Completely doable:\n",
    "\n",
    "```python\n",
    "'b c (h hI) (w wI) -> b c (h hO) (w wO)', weight_shape='c hI wI hO wO'\n",
    "```\n",
    "\n",
    "We split tensor into patches of shape `hI wI` and mixed things within channel.\n",
    "  \n",
    "### Mixing in subgrids\n",
    "\n",
    "Ok, done with local mixing. How to collect information from the whole image? <br  />\n",
    "Well, you can again densely connect all the tokens, but all-to-all connection is too expensive.\n",
    "\n",
    "\n",
    "\n",
    "TODO need some image here to show sub-grids and information exhange. \n",
    "\n",
    "\n",
    "Here is EinMix-way: split the image into subgrids (each subgrid has steps `h` and `w`), and connect densely tokens within each subgrid\n",
    "\n",
    "```python\n",
    "'b c (hI h) (wI w) -> b c (hO h) (wO w)', weight_shape='c hI wI hO wO'\n",
    "```\n",
    "\n",
    "\n",
    "### Going deeper\n",
    "And that's very top of the iceberg. <br />\n",
    "\n",
    "- Want to mix part of axis? — No problems!\n",
    "- ... in a grid-like manner — Supported! \n",
    "- ... while mixing channels within group? — Welcome! \n",
    "- In 2d/3d/4d? — Sure!\n",
    "- I don't use pytorch — EinMix is available for multiple frameworks!\n",
    "\n",
    "Hopefully this guide helped you to find MLPs more interesting and intriguing. And simpler to experiment with."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
